{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fgfyb_2c9WL4"
      },
      "source": [
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Trusted-AI/AIF360/blob/main/examples/sklearn/demo_grid_search_reduction_regression_sklearn.ipynb)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_uAgvnjw71C_"
      },
      "source": [
        "# Sklearn-compatible Grid Search for regression\n",
        "\n",
        "Grid search is an in-processing technique that can be used for fair classification or fair regression. For classification, it reduces fair classification to a sequence of cost-sensitive classification problems and returns, among the candidates searched, the deterministic classifier with the lowest empirical error subject to fair classification constraints.\n",
        "For regression, it uses the same principle to return a deterministic regressor with the lowest empirical error subject to the constraint of bounded group loss. The code for grid search wraps the source class `fairlearn.reductions.GridSearch` from the [fairlearn](https://github.com/fairlearn/fairlearn) library, licensed under the MIT License, Copyright Microsoft Corporation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "4Ad7nC129i8f",
        "outputId": "0ce37c29-d5a5-4682-968e-bb33cea9de19"
      },
      "outputs": [],
      "source": [
        "# Install aif360 with the Reductions extra (from Fairlearn)\n",
        "!pip install aif360[Reductions]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sivw3vma71DE"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import pandas as pd\n",
        "\n",
        "from sklearn.compose import TransformedTargetRegressor\n",
        "from sklearn.linear_model import LinearRegression\n",
        "from sklearn.metrics import mean_absolute_error\n",
        "from sklearn.preprocessing import MinMaxScaler\n",
        "\n",
        "from aif360.sklearn.datasets import fetch_lawschool_gpa\n",
        "from aif360.sklearn.inprocessing import GridSearchReduction\n",
        "from aif360.sklearn.metrics import difference"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uMTJ62DZ71DF"
      },
      "source": [
        "### Loading data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KgqbgiIY71DG"
      },
      "source": [
        "Datasets are formatted as separate `X` (# samples x # features) and `y` (# samples x # labels) DataFrames. The index of each DataFrame contains the protected attribute values for each sample. Datasets may also load a `sample_weight` object to be used with certain algorithms and metrics. This layout keeps aif360 compatible with scikit-learn objects.\n",
        "\n",
        "For example, the following lines load the train and test splits of the law school GPA dataset from tempeh:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 238
        },
        "id": "7Xqg6Yn771DG",
        "outputId": "00fda049-0c6a-4c86-9ee8-f49b9e2d5161"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th>race</th>\n",
              "      <th>lsat</th>\n",
              "      <th>ugpa</th>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>race</th>\n",
              "      <th>gender</th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0.0</th>\n",
              "      <th>1</th>\n",
              "      <td>0.0</td>\n",
              "      <td>38.0</td>\n",
              "      <td>3.3</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th rowspan=\"4\" valign=\"top\">1.0</th>\n",
              "      <th>0</th>\n",
              "      <td>1.0</td>\n",
              "      <td>34.0</td>\n",
              "      <td>4.0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>1.0</td>\n",
              "      <td>34.0</td>\n",
              "      <td>3.9</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>1.0</td>\n",
              "      <td>45.0</td>\n",
              "      <td>3.3</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>1.0</td>\n",
              "      <td>39.0</td>\n",
              "      <td>2.5</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "             race  lsat  ugpa\n",
              "race gender                  \n",
              "0.0  1        0.0  38.0   3.3\n",
              "1.0  0        1.0  34.0   4.0\n",
              "     0        1.0  34.0   3.9\n",
              "     0        1.0  45.0   3.3\n",
              "     1        1.0  39.0   2.5"
            ]
          },
          "execution_count": 2,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "X_train, y_train = fetch_lawschool_gpa(\"train\", numeric_only=True, dropcols=\"gender\")\n",
        "X_test, y_test = fetch_lawschool_gpa(\"test\", numeric_only=True, dropcols=\"gender\")\n",
        "X_train.head()"
      ]
    },
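    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the index convention concrete, the next cell is a minimal sketch on a made-up toy frame (`toy_X` and its values are illustrative assumptions, not part of the dataset). It shows one way to read protected attribute values from a `MultiIndex` and to aggregate by group using only pandas."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import pandas as pd\n",
        "\n",
        "# Toy frame mimicking the layout above; all values are made up for illustration\n",
        "toy_index = pd.MultiIndex.from_arrays([[0.0, 1.0, 1.0], [1, 0, 0]], names=[\"race\", \"gender\"])\n",
        "toy_X = pd.DataFrame({\"lsat\": [38.0, 34.0, 45.0], \"ugpa\": [3.3, 4.0, 3.3]}, index=toy_index)\n",
        "\n",
        "# Protected attribute values live in the index levels\n",
        "toy_race = toy_X.index.get_level_values(\"race\")\n",
        "print(toy_race.tolist())\n",
        "\n",
        "# Group-wise summaries can use the index level directly\n",
        "print(toy_X.groupby(level=\"race\")[\"lsat\"].mean())"
      ]
    },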
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mgjzey6y71DI"
      },
      "source": [
        "We min-max scale the feature values and rebuild the DataFrames with their original column names and index, since grid search reduction needs the protected attribute information they carry."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 238
        },
        "id": "7CxY3K8i71DI",
        "outputId": "c14dec8a-8a45-4992-af71-fd6cc5761cc0"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th>race</th>\n",
              "      <th>lsat</th>\n",
              "      <th>ugpa</th>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>race</th>\n",
              "      <th>gender</th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0.0</th>\n",
              "      <th>1</th>\n",
              "      <td>0.0</td>\n",
              "      <td>0.729730</td>\n",
              "      <td>0.825</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th rowspan=\"4\" valign=\"top\">1.0</th>\n",
              "      <th>0</th>\n",
              "      <td>1.0</td>\n",
              "      <td>0.621622</td>\n",
              "      <td>1.000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>1.0</td>\n",
              "      <td>0.621622</td>\n",
              "      <td>0.975</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>1.0</td>\n",
              "      <td>0.918919</td>\n",
              "      <td>0.825</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>1.0</td>\n",
              "      <td>0.756757</td>\n",
              "      <td>0.625</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "             race      lsat   ugpa\n",
              "race gender                       \n",
              "0.0  1        0.0  0.729730  0.825\n",
              "1.0  0        1.0  0.621622  1.000\n",
              "     0        1.0  0.621622  0.975\n",
              "     0        1.0  0.918919  0.825\n",
              "     1        1.0  0.756757  0.625"
            ]
          },
          "execution_count": 3,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "scaler = MinMaxScaler()\n",
        "\n",
        "X_train = pd.DataFrame(scaler.fit_transform(X_train), columns=X_train.columns, index=X_train.index)\n",
        "X_test = pd.DataFrame(scaler.transform(X_test), columns=X_test.columns, index=X_test.index)\n",
        "\n",
        "X_train.head()"
      ]
    },
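    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick check of what `MinMaxScaler` does, the sketch below compares it with the formula `(x - min) / (max - min)` computed by hand. The column `demo_col` is a made-up assumption used only for illustration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "from sklearn.preprocessing import MinMaxScaler\n",
        "\n",
        "# Made-up single-feature column, for illustration only\n",
        "demo_col = np.array([[10.0], [20.0], [40.0]])\n",
        "\n",
        "# Min-max scaling by hand: (x - min) / (max - min), computed per column\n",
        "manual = (demo_col - demo_col.min(axis=0)) / (demo_col.max(axis=0) - demo_col.min(axis=0))\n",
        "scaled = MinMaxScaler().fit_transform(demo_col)\n",
        "\n",
        "# Both approaches should agree up to floating point error\n",
        "assert np.allclose(scaled, manual)\n",
        "print(scaled.ravel())"
      ]
    },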
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lukGdCjv71DJ"
      },
      "source": [
        "### Running metrics"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x9LheHu-71DK"
      },
      "source": [
        "With the data in this format, we can easily train a scikit-learn model and get predictions for the test data. We drop the protected attribute column so that it is not used by the model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "y3mrAFw471DL",
        "outputId": "d7bd400b-2fc1-426b-baab-93139fb48739"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "0.7400826321650612"
            ]
          },
          "execution_count": 4,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "tt = TransformedTargetRegressor(LinearRegression(), transformer=scaler)\n",
        "tt = tt.fit(X_train.drop([\"race\"], axis=1), y_train)\n",
        "y_pred = tt.predict(X_test.drop([\"race\"], axis=1))\n",
        "lr_mae = mean_absolute_error(y_test, y_pred)\n",
        "lr_mae"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M2cLBv8q71DL"
      },
      "source": [
        "We can easily measure how the mean absolute error differs across groups:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "F97LRCdf71DM",
        "outputId": "fe6f95a7-fdb9-4a47-bd39-284ef726f983"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "0.20392590525744636"
            ]
          },
          "execution_count": 5,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "lr_mae_diff = difference(mean_absolute_error, y_test, y_pred, prot_attr=\"race\")\n",
        "lr_mae_diff"
      ]
    },
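    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To illustrate the idea behind a group difference, the sketch below computes the mean absolute error separately for two groups of a made-up example and subtracts one from the other. It uses only pandas; the frame `demo_scores`, its values, and the sign convention (group 0 minus group 1) are illustrative assumptions rather than a description of how `difference` is implemented."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import pandas as pd\n",
        "\n",
        "# Made-up labels and predictions for two groups, for illustration only\n",
        "demo_scores = pd.DataFrame({\n",
        "    \"race\": [0, 0, 1, 1],\n",
        "    \"y_true\": [3.0, 2.0, 3.5, 2.5],\n",
        "    \"y_pred\": [2.0, 2.5, 3.5, 2.0],\n",
        "})\n",
        "\n",
        "# Mean absolute error within each group\n",
        "demo_scores[\"abs_err\"] = (demo_scores[\"y_true\"] - demo_scores[\"y_pred\"]).abs()\n",
        "mae_by_group = demo_scores.groupby(\"race\")[\"abs_err\"].mean()\n",
        "\n",
        "# Gap between the two groups (assumed convention: group 0 minus group 1)\n",
        "mae_gap = mae_by_group.loc[0] - mae_by_group.loc[1]\n",
        "print(mae_by_group)\n",
        "print(mae_gap)"
      ]
    },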
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8ADb1vbQ71DM"
      },
      "source": [
        "### Grid Search"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XkoNE09071DM"
      },
      "source": [
        "Reuse the base model for the candidate regressors. Base models should implement a `fit` method that accepts a `sample_weight` argument. For details, refer to the docs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "TxejOHFK71DN"
      },
      "outputs": [],
      "source": [
        "estimator = TransformedTargetRegressor(LinearRegression(), transformer=scaler)"
      ]
    },
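    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a small illustration of the `sample_weight` requirement, the sketch below fits a plain `LinearRegression` on made-up data with and without weights. The arrays `demo_X`, `demo_y`, and `demo_w` are illustrative assumptions; the point is only that `fit` accepts the argument and that the weights change the fitted model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "from sklearn.linear_model import LinearRegression\n",
        "\n",
        "# Made-up data: the first three points lie on y = x, the last one does not\n",
        "demo_X = np.array([[0.0], [1.0], [2.0], [3.0]])\n",
        "demo_y = np.array([0.0, 1.0, 2.0, 6.0])\n",
        "\n",
        "# A zero weight removes the last point's influence on the fit\n",
        "demo_w = np.array([1.0, 1.0, 1.0, 0.0])\n",
        "\n",
        "unweighted = LinearRegression().fit(demo_X, demo_y)\n",
        "weighted = LinearRegression().fit(demo_X, demo_y, sample_weight=demo_w)\n",
        "\n",
        "# The two slopes differ because the weights change which points matter\n",
        "print(unweighted.coef_, weighted.coef_)"
      ]
    },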
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WyqzfrL271DN"
      },
      "source": [
        "Search for the best regressor and observe the mean absolute error. Grid search for regression takes `constraints=\"BoundedGroupLoss\"` to select bounded group loss as its constraint. We therefore also need to specify a loss function, such as \"Absolute\"; other options include \"Square\" and \"ZeroOne\". When the loss is \"Absolute\" or \"Square\", we also specify the expected range of the y values in `min_val` and `max_val`. For details on the implementation of these loss functions, see https://github.com/fairlearn/fairlearn/blob/main/fairlearn/reductions/_moments/bounded_group_loss.py in the fairlearn library."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "B63yaCCf71DN",
        "outputId": "e3e2c851-b3f0-4296-e868-a46b78899364"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "0.7622719376746614\n"
          ]
        }
      ],
      "source": [
        "np.random.seed(0)  # needed for reproducibility\n",
        "grid_search_red = GridSearchReduction(prot_attr=\"race\",\n",
        "                                      estimator=estimator,\n",
        "                                      constraints=\"BoundedGroupLoss\",\n",
        "                                      loss=\"Absolute\",\n",
        "                                      min_val=y_train.min(),\n",
        "                                      max_val=y_train.max(),\n",
        "                                      grid_size=10,\n",
        "                                      drop_prot_attr=True)\n",
        "grid_search_red.fit(X_train, y_train)\n",
        "gs_pred = grid_search_red.predict(X_test)\n",
        "gs_mae = mean_absolute_error(y_test, gs_pred)\n",
        "print(gs_mae)\n",
        "\n",
        "# Check that the mean absolute error stays comparable to the baseline\n",
        "assert abs(gs_mae - lr_mae) < 0.08"
      ]
    },
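    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To give some intuition for the bounded group loss constraint, the sketch below computes the average absolute loss per group on made-up data and checks each group against a bound. This is a simplified illustration, not fairlearn's implementation; the frame `demo_losses` and the bound `zeta` are hypothetical."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import pandas as pd\n",
        "\n",
        "# Made-up labels and predictions for two groups, for illustration only\n",
        "demo_losses = pd.DataFrame({\n",
        "    \"group\": [0, 0, 0, 1, 1],\n",
        "    \"y_true\": [1.0, 2.0, 3.0, 2.0, 4.0],\n",
        "    \"y_pred\": [1.5, 2.0, 2.0, 2.5, 3.5],\n",
        "})\n",
        "zeta = 0.6  # hypothetical upper bound on each group's average loss\n",
        "\n",
        "# Average absolute loss within each group\n",
        "abs_loss = (demo_losses[\"y_true\"] - demo_losses[\"y_pred\"]).abs()\n",
        "group_loss = abs_loss.groupby(demo_losses[\"group\"]).mean()\n",
        "\n",
        "# The constraint holds only if every group stays at or below the bound\n",
        "print(group_loss)\n",
        "print(bool((group_loss <= zeta).all()))"
      ]
    },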
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "WPd4i-Zj71DO",
        "outputId": "d40bcb93-fcb1-405b-ffc3-a72ef29157e1"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "0.06122151904963535\n"
          ]
        }
      ],
      "source": [
        "gs_mae_diff = difference(mean_absolute_error, y_test, gs_pred, prot_attr=\"race\")\n",
        "print(gs_mae_diff)\n",
        "\n",
        "# Check that the difference between groups decreased\n",
        "assert gs_mae_diff < lr_mae_diff"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3.9.7 ('aif360')",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.4"
    },
    "vscode": {
      "interpreter": {
        "hash": "d0c5ced7753e77a483fec8ff7063075635521cce6e0bd54998c8f174742209dd"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
